# -*- coding: utf-8 -*-
"""
Created on Thu Jun 13 17:07:14 2019

@author: jlani
"""

import numpy as np

class Nested_Crra_UM:
    '''Nested CRRA (power-mean / CES) utility over state-contingent consumption streams.

    A choice ``c`` is an S-list of T-vectors: one consumption stream per state.
    Each stream is first aggregated over time by ``U_time``, using geometric
    discount weights (``discount_rate``) and curvature ``crra_time``. The S
    per-state utilities are then aggregated by ``U_risk`` with curvature
    ``crra_risk``.

    Out-of-range parameters are clamped on construction: ``discount_rate`` to
    [0, 1] and both curvatures to >= 0. The total distance clamped away is
    stored in ``self.penalty``.
    '''

    def __init__(self, discount_rate, crra_time, crra_risk):
        self.dr = discount_rate
        self.crra_time = crra_time
        self.crra_risk = crra_risk

        # Distance by which the inputs had to be moved into the admissible
        # region; used to penalize values outside the acceptable range.
        penalty = 0.0
        if self.dr < 0.0:
            penalty -= self.dr
            self.dr = 0.0
        elif self.dr > 1.0:
            penalty += self.dr - 1.0
            self.dr = 1.0
        if self.crra_time < 0.0:
            penalty -= self.crra_time
            self.crra_time = 0.0
        if self.crra_risk < 0.0:
            penalty -= self.crra_risk
            self.crra_risk = 0.0
        self.penalty = penalty

    def get_time_measure(self, T):
        '''Return the normalized discount weights for a horizon of T periods.

        Entry t is dr**t / sum_{s<T} dr**s, so the weights sum to one.
        '''
        dr_sum = sum(self.dr ** t for t in range(T))
        return np.array([self.dr ** t / dr_sum for t in range(T)])

    def U_time(self, c):
        '''Utility of a single consumption stream c (a T-vector).

        Weighted power mean of c, using the weights w from get_time_measure:
        (sum_t w_t * c_t**(1 - rho))**(1 / (1 - rho)) with rho = crra_time.
        When rho == 1 it is the weighted geometric mean exp(sum_t w_t log c_t).
        Returns 0.0 when rho >= 1 and some c_t <= 0, because the power or log
        would otherwise be infinite or undefined.
        '''
        c = np.asarray(c, dtype=float)
        time_measure = self.get_time_measure(len(c))
        crrat = self.crra_time
        if crrat >= 1. and (c <= 0.).any():
            return 0.
        if crrat == 1.:
            return np.exp(time_measure @ np.log(c))
        return (time_measure @ (c ** (1. - crrat))) ** (1. / (1. - crrat))

    def U_risk(self, u_time, measure):
        '''Aggregate per-state utilities u_time (from U_time) across states.

        Same power-mean form as U_time, with rho = crra_risk and state weights
        given by ``measure`` (an S-vector). Returns 0.0 when rho >= 1 and some
        state utility is <= 0.
        '''
        u_time = np.asarray(u_time, dtype=float)
        measure = np.asarray(measure, dtype=float)
        crrar = self.crra_risk
        if crrar >= 1. and (u_time <= 0.).any():
            return 0.
        if crrar == 1.:
            return np.exp(measure @ np.log(u_time))
        return (measure @ (u_time ** (1. - crrar))) ** (1. / (1. - crrar))

    def U(self, c, measure):
        '''Nested utility of c, an S-list of T-vectors (one stream per state).

        Each stream is aggregated with U_time. The S results are then
        aggregated with U_risk, using the S-vector ``measure`` as weights.
        '''
        u_time = np.array([self.U_time(c_stream) for c_stream in c])
        return self.U_risk(u_time, measure)

    def v_time(self, p):
        '''Indirect utility of U_time at the price vector p with income 1.

        With weights w from get_time_measure and rho = crra_time:
          rho == 0: max_t w_t / p_t
          rho == 1: prod_t (w_t / p_t)**w_t
          otherwise: (sum_t p_t * (w_t / p_t)**(1 / rho))**(rho / (1 - rho))
        '''
        p = np.asarray(p, dtype=float)
        time_measure = self.get_time_measure(len(p))
        crrat = self.crra_time
        if crrat == 0.:
            return max(time_measure / p)
        if crrat == 1.:
            # Zero-weight periods contribute 0 * log(...), which should be 0.
            # Put a 1 inside the log for them to avoid log(0); the outer
            # multiplication by the true (zero) weight removes the term anyway.
            time_measure2 = time_measure.copy()
            time_measure2[time_measure <= 0.] = 1.
            return np.exp(time_measure @ np.log(time_measure2 / p))
        return (p @ ((time_measure / p) ** (1. / crrat))) ** (crrat / (1. - crrat))

    def e_risk(self, v_time, measure):
        '''Return the per-unit cost of nested utility.

        ``v_time`` is the S-vector of per-state indirect utilities from v_time
        (utility per unit of income spent in that state). ``measure`` is the
        S-vector of state weights. With rho = crra_risk the result is:
          rho == 0: 1 / max_s (measure_s * v_s)
          rho == 1: exp(-sum_s measure_s * log(measure_s * v_s))
          otherwise: (sum_s measure_s**(1/rho) * v_s**((1-rho)/rho))**(-rho/(1-rho))
        Returns 0.0 when rho >= 1 and some v_s <= 0.
        '''
        v_time = np.asarray(v_time, dtype=float)
        measure = np.asarray(measure, dtype=float)
        crrar = self.crra_risk
        if crrar >= 1. and (v_time <= 0.).any():
            # Special case: zero utility in some state with crrar >= 1.
            return 0.
        if crrar == 0.:
            return max(measure * v_time) ** (-1.)
        if crrar == 1.:
            # Zero-weight states contribute 0 * log(...), which should be 0.
            # Put a 1 inside the log for them (as v_time does) so we do not
            # compute 0 * log(0) = nan.
            measure2 = measure.copy()
            measure2[measure <= 0.] = 1.
            return np.exp(- measure @ np.log(measure2 * v_time))
        return ((measure ** (1. / crrar)) @ (v_time ** ((1. - crrar) / crrar))) ** (-crrar / (1. - crrar))

    def e_fraq(self, p, c, measure):
        '''Ratio of minimum cost to actual cost of the choice c at prices p.

        ``p`` and ``c`` are S-lists of T-vectors (one per state); ``measure`` is
        an S-vector of state weights. The numerator is the per-unit cost from
        e_risk times the utility U(c, measure). The denominator is the actual
        expenditure sum_s p_s . c_s, summed over ALL S states.
        '''
        u = self.U(c, measure)
        v_time = np.array([self.v_time(p_stream) for p_stream in p])
        expenditure = self.e_risk(v_time, measure)
        # Sum over every state; this used to be hard-coded to two states.
        spent = sum(c_stream @ p_stream for c_stream, p_stream in zip(c, p))
        return (expenditure * u) / spent

def ccei_nested_crra(C, P, discount_rate, crra_time, crra_risk):
    '''Smallest per-observation efficiency ratio of choices C at prices P, plus penalty.

    Each row of C (and of P) holds two states of two periods each. Columns 0-1
    are the first state's stream and columns 2-3 the second's. Both states get
    weight 0.5. For every one of the C.shape[0] observations the ratio comes
    from Nested_Crra_UM.e_fraq. The minimum ratio is returned, plus the
    model's out-of-range parameter penalty.
    '''
    model = Nested_Crra_UM(discount_rate=discount_rate, crra_time=crra_time, crra_risk=crra_risk)

    # Split every 4-column row into its two 2-period state streams.
    consumption = [[row[:2], row[2:4]] for row in C]
    prices = [[row[:2], row[2:4]] for row in P]
    state_weights = np.array([0.5, 0.5])

    ratios = [
        model.e_fraq(p=prices[i], c=consumption[i], measure=state_weights)
        for i in range(C.shape[0])
    ]
    return min(ratios) + model.penalty

def varian_nested_crra(C, P, discount_rate, crra_time, crra_risk):
    '''Root of one minus the mean squared efficiency shortfall, plus penalty.

    Each row of C (and of P) holds two states of two periods each. Columns 0-1
    are the first state's stream and columns 2-3 the second's. Both states get
    weight 0.5. For the e_n from Nested_Crra_UM.e_fraq over the N = C.shape[0]
    observations the result is
        sqrt(1 - (1/N) * sum_n (1 - e_n)**2) + penalty
    where penalty is the model's out-of-range parameter penalty.
    '''
    model = Nested_Crra_UM(discount_rate=discount_rate, crra_time=crra_time, crra_risk=crra_risk)
    n_obs = C.shape[0]

    # Split every 4-column row into its two 2-period state streams.
    consumption = [[row[:2], row[2:4]] for row in C]
    prices = [[row[:2], row[2:4]] for row in P]
    state_weights = np.array([0.5, 0.5])

    squared_shortfalls = [
        (1. - model.e_fraq(p=prices[i], c=consumption[i], measure=state_weights)) ** 2
        for i in range(n_obs)
    ]
    mean_squared_shortfall = sum(squared_shortfalls) / n_obs
    return (1. - mean_squared_shortfall) ** 0.5 + model.penalty





